{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n",
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "import cv2\n",
    "import torch\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from torch.utils.data import Dataset,DataLoader\n",
    "from torchvision import transforms,models\n",
    "from PIL import Image\n",
    "from keras.datasets import cifar10\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 加载数据\n",
    "可视化数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cifar10.load_data() is indexed as one (images, labels) pair per split.\n",
    "train, test = cifar10.load_data()\n",
    "(train_x, train_y), (test_x, test_y) = train, test\n",
    "# Flatten both label arrays to 1-D with reshape(-1).\n",
    "train_y = train_y.reshape(-1)\n",
    "test_y = test_y.reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAATwAAACuCAYAAACr3LH6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJztnWmMJPd53p+3+u7pmZ2d2d3hcne5yysUKcqWFEaW4CAxbAtg/EWBrQS0gUAfBAhwFMDKAZhOgCCHE1hJYOdDLhCQYgExrDCREAmBEkNhZMtBAoo0RUq8uctDJPfe2dm5+qqqfz7McKeep3a6e2dnuoes9wcsdt6u619Vb/+n6pn3sBACHMdxikA06QE4juOMC5/wHMcpDD7hOY5TGHzCcxynMPiE5zhOYfAJz3GcwuATnuM4hcEnvD3GzE6Z2XfM7KqZnTezf2Nm5UmPy/lgYWb3mlnHzP7TpMeyn/EJb+/5dwAuAjgK4KMA/jKAvznRETkfRP4tgKcmPYj9jk94e8+dAB4PIXRCCOcB/E8AH57wmJwPEGb2CIAlAE9Meiz7HZ/w9p5/DeARM2ua2TEAfwUbk57j3DJmNgPgnwD4O5Mey/sBn/D2nu9j44luGcA7AJ4G8N8mOiLng8Q/BfCVEMI7kx7I+wGf8PYQM4uw8TT3TQBTAA4BOAjgy5Mcl/PBwMw+CuAXAfzepMfyfsG8WsreYWaHAFwCMBtCuLb52V8F8NshhAcnOjjnfY+ZfQnAPwOwsvlRC0AJwEshhI9PbGD7GJ/w9hgzex3AYwD+FTYc8j8CaIcQfm2iA3Pe95hZE8BM5qO/B+AUgF8PIVyayKD2Of5Ku/f8MoCHsfGkdxpAH8DfnuiInA8EIYT1EML59/4BWAXQ8clue/wJz3GcwuBPeI7jFAaf8BzHKQw+4TmOUxhuacIzs4fN7BUzO21mj+7WoBwHcP9ydp8d/9HCzEoAXgXwaWxkEDwF4FdDCC/u3vCcouL+5ewFt1Km6BMATocQXgcAM/s6gM8A2NYhD87Nh2MnTl63b3ay1bWHbZ9fX5fnPhi4h/z2Q/Z3E/setn46ZHl+9/qJ3eT6Qxi2viz+yesvXQ4hHL6JI9y0fzWnpsPM7PaHMLNtl914/dwnQ7YY7C/57XmFYcPT/aVpcv3nXr9Ly2Kx+2JntwWANGE7ikpkl8s8VSSyfqlcIdsifnlM0xi8gprRwOVJ3CNbv/v9dm8k/7qVCe8YgLcz9jsAfmbgBidO4vH/8SfX7TTNfY2JRG5wnPD6ur3afdm+n/IHetPy+wtiy/4SvomxOHAatjYw2VeQnekNVLsXs0Mkur8w+NxCkAlPzmXYvQhy7a2fbLPmjff3xc8+9NbADfLctH/NzB7G5774zzOf8DWJ5Euotk6IN2vrOesElci32GR8pSETXio3bW196frP77x7hpZduPAm2efOnya7u7ZKdnv5Ctn1xgGy5w/eRvbq6lWyZw7x8lK9IWPl/ZfLJjavj4ivzdIVThVOYp7A3372JyP5157/0cLMvmBmT5vZ04tXLu/14ZyCkfWv9trK8A2cQnMrE967AE5k7OObnxEhhMdCCA+FEB6amz90C4dzCsZN+1djanpsg3Pen9zKK+1TAO41szux4YiPABiSHxroMV4f6RXTVwB55I/kg6DLdX+qG6gtG0S6Q31Fkh2U9JUn8wZiJq+7smvV//T1pRQNft/JvX7n1tDXO9Zohr1i6/vZMD0sGjLeEdiBf8lrpek5yLr6mi4OkD+Dwa+sUYm/TurfpjpbzLKA1ZsyXJEtxKPrtanrPx+YnqNly8v8yjk7w6+ca+Dss2rC/lBrzZDdEw2tF6+TvbJ+juxGyttXoioY+e4Gvhb9Hr9yd5fZTnb4x9YdT3ghhNjM/haAP8JGhYavhhBe2On+HCeL+5ezF9xSM5kQwncAfGeXxuI4hPuXs9t4poXj
OIVh7O0CI9JBRBPRdUVT0dlZJbb89mqLJjJE48ktF90qJ1PpB5nxR6rnyKaqJ+kauZCFIRpaSc9Vrp5qdjp01RBVD1NNUplEFZ7IDM3qlhalGpiVWKdSjS1VYRW8PkxCcfQc5ZqUJbSit3qN7PZahzeP+I96/Q7rZK0DR/jwGR32toXjtKzZZD3wdJnHst7lMJHZo7fz2Nr8F++VpUWy69O8/7TMIVorq7J+tUV2P+mT3eus8f4S1gxLdY7zKzV2NnX5E57jOIXBJzzHcQqDT3iO4xSGsWt4LJQN1rVyEplqLlAdSmLNNE5qwEhu9MnA3FjcSFfbfvtoyNnl9cchua96cpqaNiT5VuOYynJxVPPTaz0sFY6CEMdEZCnq0Vasm6Z6WSK6aFk0OpXscvdA0/XElvGkPdalbI1j45pljk2rBtb0Ivl21kqse1nmGifgZUmNz70uGl69wqlc15Y4jq7f5bGUq6yhVWp8sXpgza0yzesnEoOY9NgOotmZ3Juoztequ8rXdlT8Cc9xnMLgE57jOIXBJzzHcQrD2DW8rC4yOFM1r8nlNLsh9cRUpxoWNxdSHYGU75H1N2pUZsYn5aIiOv7guLnctdDcVc0rzp2b6FPyqyyJJTdXVtDjpUPi7jTvWHN5dyGX9uYxIKpmjtsXXUttlpEQRXyNalXRkfQaSU05rRm33Fsiu91l3ancH6wBzkqxjSRmDTCKMttLrmsj5ZO79/gC2U0RoF89w9sv9jjXVjXeJPC1jBp8reK2XFyJcazMlMTmuL445uWdRY4LTNf5+KPiT3iO4xQGn/AcxykMPuE5jlMYxqrhGfI6SJZhPSJyml5uXxJ3pdsPOaLmn6a5uDzZfy7Ob/tc3byiNSzqUGPAVNMTO6f56diYvEY4OAZx0LkBeY1vErm0wEYs3nuUJPasEnFsWFlqAsYxx54F0WQ1NlFzczWub6rJBUlLC3Wyr13ifNPeGufaludE55K4vTQTB2g11VDJxNUV1g/XOqwHxqLJ1Rs81kaDr921VHJftcdET2r9iYaXSsxjVOcP+tc4jziV2oG1g1ISfsQGAv6E5zhOYfAJz3GcwuATnuM4hWHMcXjc0yJXn07WTkWJ0rZ6uVzGXHky1bl4eZSrx6c70PHoACUuUPNRqafF4P4cwxpu5PTN3P4Gx+Vpnqjm2uZiILX/g/5q1HJ52g8E48cAVDLjrlRrtLzXG1wD0CRXti+aXkk0u2qZ959KPmm9ybrX1DT3eVD/kTYOWFpknawxzTdlKpOvGktM4NKy6INd1sTqEmN4aIbH1q2xhnfuwlleXmbNr9GQfhxy7klXcrm1zac4ZNrh5fVpHg/40o+MP+E5jlMYfMJzHKcw+ITnOE5hmGg9vJyupWsOyRdVHWt47NeQPrhDdLFcC4ygGp7UkMvsLtfHNSfZ6bFFv8zFwclYcjGLvDynCer2emnyIiNZqulpnNWkfpXGWf+SPiHlkuhEUoMtkeTaXF8Qicvr99iOU6kJJ7nWlcA6WtTj/NBp6UORRFNk92O+xssrW8frJxzDV5G84GaFt12UuLvbjx4l+4WXXya7F/hcG5L7GgW+4fESa3CJ3IvEeH/a0yUKrAH22zzeeM1zaR3HcQbiE57jOIXBJzzHcQrDBDS8LS1hqOaWC4u7ufzRofaQ2De105zOJsFtGjdIGtDgHqb5Q+ea5A4cW+5cdGTDevIOGV++Z4WYQ2sL7j0BASGjNfV6HEdXk96p1QqPsSt9YNNE8kFFo11eZt1M+zo0S5zvWV2Xe1bm5SXJh51u8ddz8SrH5b326lYCadN4LLfPcKBaJeVzPyL1ENc7rF/2pOdEo8l6YqPEeb3t1TbZSVv0TdE/U4nD63VFP62whqfb5/qRjIg/4TmOUxh8wnMcpzD4hOc4TmEYs4ZnCBnxJ0n3tndpaZhGpzKZ6FB9TTiVRqGR9m4VnauUOYDWG1M0j1OTVYP2z5C4J21iEeTapnL81LSfwmCNL5Xx
qX5pmkybb4y755QiYCYjNQW9pKIbaU+KVGLN+hKX15Y+Dav9ZbI7xpphN2EdagHzZFemDpD9yukXyQ5dtk38rdTfGm+rIT1LjOPkGjXW3BZmD5MdVVpkz01xLb/vv8pjWVrjGMIg/hhVZWrpD46pVVvDOiui2YWe3tzR8Cc8x3EKw9AJz8y+amYXzez5zGdzZvZdM3tt8/+DeztM54OM+5gzLkZ5wvt9AA/LZ48CeCKEcC+AJzZtx9kpvw/3MWcMDNXwQgjfN7NT8vFnAPzc5s9fA/DHAH7zZg8+rC/trZKrd6Y6QU7nGtYTQ+vnaf7f9rFtOU1Ma+kNi5MbsO/3Rru9NTxXN4fW19M4Ox3PkJjIQeyWj5UiQ2tqy6X7fdZ92iL7rK2yJteXvrVd6SPbkVi1RK5hEkTXBOtmJn1rz7zwDNk/fu6HZNeMfeTY7ZzveufJO67/3JIeFNU6x83pDYlqrNFpD4v7br+D7K7c3/975nmyV9f5WtWaUr9Om1hoj9++fhd59aTLMY7xmDW8hRDCuc2fzwNYGLSy4+wA9zFn17nlP1qEjV/12z6cmdkXzOxpM3t68crlWz2cU0AG+VjWv5ZXVm60iuNcZ6cT3gUzOwoAm/9f3G7FEMJjIYSHQggPzc0f2uHhnAIyko9l/WtmevpGqzjOdXYah/dtAJ8D8Dub/39rtM0CLKt75erN3ZyKZ7k+tKqxDZ7Ph+pYWo8v1whCdAc9fmbzXF/YIf0zSqq5aO091cxyPTA0Dm6wxpbvwTt4/6n2GRVNsnRTKt4NuXkfCwEhzuTSdjkubl36OqSpxCbGbbFZN0rB63ckbs/k61QXO4l5+8UL3Ex1ps7XbGHuCNmNGi+vZur7af+Oeo3tToevxcoKN9BYW2O7NX+c7EMzx8i+fY7f1t5M+VysIjGO0lc2Ej2zvcbXuizfmIrkMddKHOO4dG60t8dRwlL+EMD/A3Cfmb1jZp/HhhN+2sxeA/CLm7bj7Aj3MWdcjPJX2l/dZtEv7PJYnILiPuaMC8+0cBynMIy/Hl5Ge7JhzVFvYl83tFVT08MNLgE3NO4u19Mi1fWzI9FGr1LfS/thyLFKurkeW3JZNY5P6+HpeJKgY9eLIzFnqvlJXqr2cB0HSQCWO1vH7XTYveOUY9MSSI02Y10orXAsWYg41ixts+7VkZ4WNXmeiERHa1Skz+0c59a2GtJ8Ve5hN1NDrtnkZW05Vr8veb4d1hNNGheHBuuZb5y7QPbKKu+veYD/YNSVmMVazDGJ+k0tp3KvxOHL+t3tei6t4zjOQHzCcxynMPiE5zhOYRi7hpfVxTQMTnWpYT0vVANUzU37QIShzVhF9xJdqiwD1rL6qmtla8SVJRe1J0NNpXibnltJNTotP6d5wHJuQfavfW41zk41vXzq7uA+u5Ooh5cEw1p/ayCpsa6USg+LBKIzVVlnqsg9q8Ws+c1KicOe9F1AzDpX7yzXlGvVWaPrS76p5i9XKqLpZfy5vc5ZJok4iH6Xej3poVvhc+/LuVy6fInsFclqOXyca/0FCSQtyfHjmK/91RLroX2p11iT+ovRNE9dr2I0/AnPcZzC4BOe4ziFwSc8x3EKw1g1vI2OFlvv8qpRaBGsnI6kDKtfl8svHZwvqjqUSn5rq9z784pUf9F6almdq9YcnNjemuKeAkmiuYccE6aaYCz6kmo2+pstFxeYq/0n6+f0Ul7DpKeG5taOAwsxSvHV63YuH1N6qYaSXjPWlbQmW5DYQqvz/spyj2LpabGa3sbbry+SXRJdC8ngfOis/7fXWQOLtceE5J5qbmutxT1yoybHBNab7J/RNdYTTyx8VMbKx08S/m70exzTONvh/iCqaVdkfx2pVTgq/oTnOE5h8AnPcZzC4BOe4ziFYexxeFmdTWPFVINTe9C+NhicS5urASdxdrp5JD0FzrzyAtlPPfUU2dncRgDo9bZ0i770O/jpj32M7I88+CDZquFNHWTNJFGBSXNlVd8U
PVN77ibx4NzeYXF92pdW+4qOB6OadG3tQVGROErZOorkmvRYF2tIjbkoknp6fd2eY92qLe4V25yeJdvWrpLda3Mcn/rE2tpWfb+SBIVWa5w3nL8/EmM6xWPpG+uTx4+dJPvCBc6tbcxw7b4p2V8iemScaL08PtdY85x7bHdXWE8fFX/CcxynMPiE5zhOYRjvK22QUuXaOnBYSfXc/ganoqWmr2W8+bBX5iDpNQuH5sg+efx2siN5DbyyuBV20JO0obIM5uUXue3dPffcK+vnRkdWrvyT2Po6pOWmopLmqrGZ5NpM8oByLTF3venmcEqlCloHMq9WdT7nvoRiSEWkXLodIn4tXJMS7ZVUrwnvv93lEkvVspRhn+He4usrS7I/9u+O7C97z2ZqvK9qhb/awaTcvHx3eon4kxzr0EEOU/nQ/ffw8cCvmJVc6Tb5rsu1sApf676UT0tr7J/V5s76svsTnuM4hcEnPMdxCoNPeI7jFIbxl3jPkI9c2FvdRysaaVl0MdGTMti1Kl+u++69m+xp6Yv6Z3/2zPWfqy3WHNbarJGo/jgnmkku9Us1s0jLAWl5+2EhPoM/0DQ9JdUwgwmklpWigIONTOvCiM+5VOV70JheIDtI6MTly+fI7snyirQatHSwDtbPRVGxbhXOcavDUpM146R9nuypqa3QkakDXJ4p6bHvauphScpBdaQlZWWNNTktT3b3cW7j2BdROIo4zAR9SZtL5bsnYTDosz91Zf047Gzq8ic8x3EKg094juMUBp/wHMcpDBNILdtC05lyJd1Fl8qlU+XK5Ui6E9jW/avOlcr+L15kDefHz/2Q7I60wnv7Jz8hu5TReO68hy/12XfPkv2pT/0s2RrTl0jpqVIkbfVEM0vl2lYkzk7CrvKaW64Dplxb+VWpMZTahnIcGIxS4ppVKQ/W4BJHVmEbEWturZaUf0p5ea3K16QnIt16ie/5csKxZ+o/qpO1jnMsZhc8nubCVhzo9MkP0bKVN3/EY7vC/taYZo24Ms+pY2mHz7V7jUuhWcxjXZfyU311p8AaoQrqQePutFScxBH2kp3lLvoTnuM4hcEnPMdxCoNPeI7jFIYxa3gBSUYryulG8l6fSjCYllyPdLoOWoKGdS/dXkPLEsmdnT8s+XqSn1gCazLT8xwLNT+/FUfVS1jDOHuONZUjC1L+W9rS5XJTc/XHh5Vo10RivdaDy0kFWZ4vly/Lo/HH4cEMVt7Skupl1pXWRANLA5dkr9XYTnvcijBe4TLk8RDdqSo3YS7i9ReTi2SXyzy+apPHM3WIfaQ2u2Unddbkaoc4Ti5Zu8JjFXeYnhX/k1zXTpljBHuX2G4dPkx2t8rlobRcWazloBK2065o0HIv65HH4TmO4wxk6IRnZifM7Htm9qKZvWBmv7H5+ZyZfdfMXtv8f2flC5xC4/7ljJNRnvBiAH83hPAAgE8C+KKZPQDgUQBPhBDuBfDEpu04N4v7lzM2hr4IhxDOATi3+fOKmb0E4BiAzwD4uc3VvgbgjwH85uB9cV22vE7EJFJfTGu6laFxd5rPxzqY6hYa63ZgZobsV157jewjR1kXWVvjVnHTs6zhra5ulQg/f5Y1u9Nvsgby9f/6DbL/2mcfIbtWFX1J4+5EMuv1VVOzgbbGOOZSZ+Vaae5srNvfIFP6Ruymf5VKEQ5m7qHGWbav8f2yeJ3srmi+iHn9mWnWtfqJxDZ2WadNReNrtSTub5pbI67O8kNsvc5fz9YUl4g/MLV1/N7aG7QsNs5drR/gvNxVaTEar7B/Bskzrh7gsakmWEn5Wk0d5jzzjvG567VRb4mlPH6/p7m4O4vzvCkNz8xOAfgYgCcBLGw6KwCcB7CwzWaOMxLuX85eM/KEZ2YtAN8A8KUQAv25KmykMNywHIeZfcHMnjazpxcXF2+0iuPsin9dW9pZYxenOIw04ZlZBRvO+AchhG9ufnzBzI5uLj8K4OKNtg0hPBZCeCiE8NDc3NyNVnEK
zm7514FZLanlOMxQDc82hLavAHgphPC7mUXfBvA5AL+z+f+3hu0rhIBOf/t3b811jSo6PGmDJ+/xcU9yE0sSdyXz+1uio128eIns1XXWeHoau5YObj0X1bY0mtuOnaBlx0+xxtFosX5YbXKttERzEyVOLw7Sxk6uZU1yHfNtFkXjy+Ud8/FV/4zSwXnK27Gb/mUIsEzOptYMXF16l+znnv/fZKeiyX34Xu5ZcvAutqt11uDqDbbb4j9VaRPZXeM2kLOHj/L+ZtgHUsmnPtDaOl4aSV5vg32fRwJcu8Ia3OpZ1qvDHLeMLFdZg6tNs7/2l94hu/POi7x/yVuuSxvHptT+q1dZL61W+dqqvj8qo0Tv/SyAvwHgx2b27OZnfx8bjvi4mX0ewFsA/vqORuAUHfcvZ2yM8lfa/4MbFSfe4Bd2dzhO0XD/csaJZ1o4jlMYxppLu95u45nntup0ae6qxtlVpIdErSL5pSlrGlMNfu+PItYxQsTLn3nmWbKfffY5spdWOJdy4eQpso9LXf/Tp0+TPZ/Jrb3jjjto2d333kf2KdH0LlxijaXbH5zr2pU4Je3XUZZ6eJGpxiYPWSLa9WPVXgdrfJMgTQO63cx1EB1xZpZ7jpy/wJqe9ln4yP2cX5qK7pn2B9fD0z7Iqyv8V+Q1qTnXmuXIm1npkYKu6q5bP1erUruvIn1fRQSu1TmuM5b+GzNV0Q9Fg+smvH57lVXCtMv9N4LU3+t0uQdvvMr6eaXMGqHV2a7UWNMbFX/CcxynMPiE5zhOYfAJz3GcwjBWDS9OYixe23p3b0jcUln6fJYlDs+kxtop0cVmZ1jzqEsPgzNvcKzQrASq3n33nWRfXeY4qZkjrOk8+eQPyH77Hd5/nImb+pVf+WVadvAgxx29/NLLZF84zxqe9kTV3NZ1ifmqVDjuTpNtSyIwJbJ/01xb0fBMNELVYyeh6SVJgsVrWzpZXeLktGfEnNz/1TXWQZtTsn2Jr0nc5fWDaHpB4vpSyQ+dmmYfOHSEY9NKZamvJxp1c2bLv7XHSEfyeut12Va+e4uyfmua83bTSPxJ6tf1a7JcNOOaXMs41vp47L9Ly5I1I/XvItEoR8Wf8BzHKQw+4TmOUxh8wnMcpzCMVcMLgftV9tf4vf3gQa65VatzbNHCIV5eEY1veZlje1ZWuUYXjDWWP3cfx74dO8Ya3dIKa3hX11nn+MRf+PNk/9RHPszbL22Npy7nMjvLcUXtNY5rWlvl/gmQmv6JxJhpf49EemgEyXVVTXBYrmw8RMPT5bkew2MgICDOaJXnLrMOqrrkwsIRstPzXM2nJLrV0lX2B0g9RtWgazWOG62WJLYt5f23JRfcRBMsiS4blbb2rxpeQzQ7kRNz96cl372DB9k/u6I/Tksu7LXA13LpCufiprHEEAbR+MocF1huiS33rt/TezEa/oTnOE5h8AnPcZzC4BOe4ziFYbx9aS1ClNGirkhNrhXRsc60WQeolVh3OCQ6g8aWqbBRb3Kcnsb5JaIzqC6lvx3uOM71y0ol1myymo7mCfe6nAd8+23c1/Ptt7nHgMYxqWi3vMyaX080lyCaifa8KJV57Bp31+8P1vC0B0YYsafFbmJmFH/Y7fM1KZX4fh8+xD1IVlbY/+o1ifWSWEPVBLtSr051snXpqZFIbveVy9xnYl5iNcX9ETLH66aDj51KrqzWbpy/nX25XNF+MxIjK5pbr8XfrbjNMYqzJ06Rvd5m/f78efb3svRw6cfsbxpnOCr+hOc4TmHwCc9xnMLgE57jOIVhvBoegJDRfuYOsW7VFw0k6XI+XQi8vNHg9/xIel5Eks+XgLdfW2dNpd/j5d2e1OuTuKmeaCqq4WV1lLJoZCXpMVGV2n13n+QeGHqsWHJjE4nhCtJjVVuJmIw1lxsr+1dNLo4Hx6ClI/a02E0MhnLYGucRyZWtSZ8E9bflJV6/Wed8Uu1JofXtaqIRijtidYnjRKck9u3wLGt2jRr799Il
jhNcXNyqIVc/wHp2Tba9Jh0DqzWJC5XtEbMGV6vw+gkG94UNEqNYb0h9xgrHJPb6fO0jmZpCkB4uLR7PqPgTnuM4hcEnPMdxCoNPeI7jFIaxanhpmpJuprqP9mnQfFOLpe9sxO/1PalPVi+zZlPJ6WjaA4PHm9O1Yt5e+9JqLFq2T0QSiz4oY11d4XMri8ZXn+Fr0ZOYsCPznNuY9jnuakXWr8j+DZr7KrmPkfS46PJ4E9FYNG5vPASUMufR62r+J+tG0iIFtx3iaxhJXOeK5Favr/M1aJbZAZZXOY70jbffJvtUlf3r4BzHBZrooNeWOG41G8s2L2NPpcfJ4jnuY35wnvXCclXydMEE6QMbVKCE9kBhf19d5msRlVmDm2nxvWk0WD8147mi3eFrPyr+hOc4TmHwCc9xnMLgE57jOIVhzBpegk5Gw8vlCsr6qsEdv4P7wNZEd3jppRfJfvfsBbIbohNk+8YCQKXE+apWlfzXXOyR9CxIto8DLIseGKQ/hzXY1vpjoc/6UST17UqiH81OsQbSWec8zbTHPXdV35xv8bW4TWrHBdFsLpzn/SfJzvqG3ippVlsSUfbyVY5Fk0sGE12q3VbNjnXRVHTL1RLrZmf7r5Ndkh4bicimmv/cFQ2y1pTes9Wt8ZYlRvXqVY4RrIpgedsJrv1oqqdL3KX25DWxI6k1qX1yD85InF1Zpx4+XlP0Vu2LXC17HJ7jOM5AfMJzHKcw+ITnOE5hGKuGV6lUsHB4Swtqr3Euq77XP/gg94i44zjrDivLrEM1m9yHdr3DmsvpN1hTee3VM2RrXKD22Jia4v1rzbGm6GaVTO0/kThytfcaou90JM6o3Wc7lTi55asc53TkCNc3a4l+qX1HTxxdIPvYUdbsVANSTefyZdaMVpb52v9L7D0BQJLpHVuVPiB9Ec16EkvY77KGp5qa9vrtdFlXemv9NbJrs+xPDzQfIFt7+apm9+5ljp2bkvqN2Xu6JPUQL55n/bopecVB9M10SM+SsuRe1/RaRFLrL9Fca9nfkEct1ajLGiZali/UiPgTnuM4hWHohGdmdTP7gZk9Z2YvmNk/3vz8TjN70sxOm9l/NrOd/dnEKTTuX844GeUJrwvg50MIPw3gowAeNrN5i7g1AAAFxElEQVRPAvgygN8LIdwD4CqAz+/dMJ0PMO5fztgYquGFDaHqvYCkyua/AODnAfza5udfA/CPAPz7gftKA+kiGmvWbbPu8+yzPyT7hR/z/rR3qvaoOHnqFNn3338/2aurHGf1/PPPk/3666z5Xb3K9cxq0vNANZ6s3ajwulWpL6ZxS7qvJFfrj8+1VOLt75BagXfcdpLsEyc5pvGA9Myoi2anMWpdydWs1binwXKLexZsx276VxoC2plxNUw0PMnvjST5OUjubCx9R3qS29oGa4DnO2+SfTfuIltj1yBxnRrbOCt9IkoSe9bL6Ljnzp2jZdckd/XEHMe8nr3EeblJzOd27Kj0uJAeFhUdS4c1w9U1/m5Vq+xPlYj9tyu1KNuSax5FWo+Rl4/KSBqemZXM7FkAFwF8F8AZAEshhPfu2DsAju1oBE7hcf9yxsVIE14IIQkhfBTAcQCfAPChUQ9gZl8ws6fN7OnV1ZXhGziFY7f8a2XF/csZzE39lTaEsATgewA+BWDWtmq2HAfw7jbbPBZCeCiE8FBLHtEdJ8ut+tf0tPuXM5ihGp6ZHQbQDyEsmVkDwKexISh/D8BnAXwdwOcAfGvYvgIC9TqYEQftSq7i2XNcP2x9hTU01eAqooP9yZ/+KdnVIZqb6mjHjvFbVK/3Ktnaw6LV4ji9bOxSKvqRxrEty7lpHJTWv2t3WP+86857yL4qcXkak1gRTWX6Ltb4ItFYVONZvMLjrUv/h/l5jmHcjt30rzjp49LSVuzaoaqMQWK7Kg32h1W5pu2u9vbla/Bu4Di59R7rlgGy/z7HqkFqDMYyvukp/n4kxj6wtrp1T0vSr+N26Yly4IDs
S2pJLq3z2C8ssf+06qyZTTU5rrM6zfbxk3eSrf4cST3GSlXj/qReo+itDe0ZPCKjBB4fBfA1Myth44nw8RDCfzezFwF83cx+G8APAXxlRyNwio77lzM2Rvkr7Y8AfOwGn7+ODb3FcXaM+5czTjzTwnGcwmCaD7qnBzO7BOAtAIcAXB6y+qTYz2MD3r/jOxlCOHyDz3eN94l/Aft7fPt5bMAt+tdYJ7zrBzV7OoTw0NgPPAL7eWyAj+/9MoZB7Ofx7eexAbc+Pn+ldRynMPiE5zhOYZjUhPfYhI47Cvt5bICPbxT2wxgGsZ/Ht5/HBtzi+Cai4TmO40wCf6V1HKcwjHXCM7OHzeyVzaKOj47z2NuM56tmdtHMns98Nmdm3zWz1zb/Hy1Ham/Gd8LMvmdmL24Wx/yN/TLG/Vi40/3rpsdXPP8KIYzlHzbazp4BcBeAKoDnADwwruNvM6a/BODjAJ7PfPYvADy6+fOjAL48wfEdBfDxzZ+nAbwK4IH9MEZsNBJtbf5cAfAkgE8CeBzAI5uf/wcAv+7+5f61X/xrnCfwKQB/lLF/C8BvTepmZ8ZxShzyFQBHMw7xyqTHmBnbt7CRXL+vxgigCeAZAD+DjaDQ8o3u+R6Pwf3r1sf6gfevcb7SHgOQLX+yX4s6LoQQ3isfex7AwqCVx4WZncJGzumT2Cdj3GeFO92/boGi+Jf/0WIAYePXyMT/jG1mLQDfAPClEAL145vkGMMtFO503L+GsRf+Nc4J710A2SJd2xZ1nDAXzOwoAGz+f3HI+nuKmVWw4Yx/EEL45ubH+2qMYQeFO/cA968dUDT/GueE9xSAezf/ylIF8AiAb4/x+KPybWwUnARGLDy5V9hG1cOvAHgphPC7mUUTH6OZHTaz2c2f3yvc+RK2CneOe2zuXzdJIf1rzOLjL2HjL0FnAPyDSQqhm+P5QwDnAPSxoQd8HsA8gCcAvAbgfwGYm+D4/iI2Xid+BODZzX+/tB/GCOCnsFGY80cAngfwDzc/vwvADwCcBvBfANTcv9y/9ot/eaaF4ziFwf9o4ThOYfAJz3GcwuATnuM4hcEnPMdxCoNPeI7jFAaf8BzHKQw+4TmOUxh8wnMcpzD8f7mwd55N9WU2AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 360x360 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Show sample #100 of the train and test splits side by side, titled with its label.\n",
    "plt.figure(figsize=(5,5))\n",
    "for pos, (images, labels) in enumerate([(train_x, train_y), (test_x, test_y)], start=1):\n",
    "    plt.subplot(1, 2, pos)\n",
    "    plt.imshow(images[100])\n",
    "    plt.title(labels[100])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 编写dataloader\n",
    "数据生成器，用于训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DataGen(Dataset):\n",
    "    \"\"\"Dataset over parallel `data` / `label` collections with an optional transform.\"\"\"\n",
    "\n",
    "    def __init__(self, data, label, transforms=None):\n",
    "        \"\"\"Keep references to the samples, their labels and the transform.\n",
    "\n",
    "        Args:\n",
    "            data: indexable collection of images; must expose `.shape` (used by `__len__`).\n",
    "            label: indexable collection of labels, indexed with the same index as `data`.\n",
    "            transforms: optional callable applied to a PIL image built from each sample.\n",
    "        \"\"\"\n",
    "        self.transforms = transforms\n",
    "        self.label = label\n",
    "        self.data = data\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the size of the first axis of `data`.\"\"\"\n",
    "        return self.data.shape[0]\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return `(img, tag)` for `index`; `img` is transformed only if a transform is set.\"\"\"\n",
    "        sample, tag = self.data[index], self.label[index]\n",
    "        if not self.transforms:\n",
    "            # No transform configured: hand back the stored sample untouched.\n",
    "            return sample, tag\n",
    "        # Cast to uint8 and wrap in a PIL image before applying the transform.\n",
    "        pil_img = Image.fromarray(sample.astype('uint8'))\n",
    "        return self.transforms(pil_img), tag"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training-time augmentation pipeline; ToTensor comes last.\n",
    "train_tf = transforms.Compose([\n",
    "    transforms.RandomHorizontalFlip(p=0.5),\n",
    "    transforms.RandomRotation(degrees=20),\n",
    "    transforms.RandomVerticalFlip(p=0.5),\n",
    "    transforms.ToTensor(),\n",
    "])\n",
    "# The test split is only converted to tensors.\n",
    "test_tf = transforms.Compose([transforms.ToTensor()])\n",
    "\n",
    "# Both loaders are built with the same DataLoader settings.\n",
    "loader_kwargs = dict(batch_size=32, pin_memory=False, drop_last=False, shuffle=True, num_workers=10)\n",
    "\n",
    "traingen = DataGen(train_x, train_y, transforms=train_tf)\n",
    "trainloader = DataLoader(dataset=traingen, **loader_kwargs)\n",
    "testgen = DataGen(test_x, test_y, transforms=test_tf)\n",
    "testloader = DataLoader(dataset=testgen, **loader_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "img shape torch.Size([32, 3, 32, 32]) tag shape torch.Size([32])\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: pull one batch from the training loader and print the tensor shapes.\n",
    "loader = iter(trainloader)\n",
    "img, tag = next(loader)\n",
    "print('img shape', img.shape, 'tag shape', tag.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 创建模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assigning this module over a layer is equivalent to removing that layer.\n",
    "class Identity(torch.nn.Module):\n",
    "    \"\"\"Module whose forward pass returns its input unchanged.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return `x` as-is.\"\"\"\n",
    "        return x\n",
    "\n",
    "\n",
    "def resnet18():\n",
    "    \"\"\"Build a pretrained torchvision ResNet-18 with `fc` and `avgpool` replaced by Identity.\n",
    "\n",
    "    Returns:\n",
    "        The modified torchvision model.\n",
    "    \"\"\"\n",
    "    backbone = models.resnet18(pretrained=True)\n",
    "    # `fc`: drop the final fully connected layer.\n",
    "    # `avgpool`: author's note - the original 224 input is reduced 32x by the network;\n",
    "    # a 32 input reduced 32x is already 1, so the global average pooling is replaced too.\n",
    "    for layer_name in ('fc', 'avgpool'):\n",
    "        setattr(backbone, layer_name, Identity())\n",
    "    return backbone\n",
    "    \n",
    "    \n",
    "class ResNet18(torch.nn.Module):\n",
    "    \"\"\"Backbone returned by `resnet18()` followed by a new linear classification layer.\"\"\"\n",
    "\n",
    "    def __init__(self, categroies):\n",
    "        \"\"\"Create the backbone and a `Linear(512, categroies)` head.\n",
    "\n",
    "        Args:\n",
    "            categroies: number of output units of the linear head.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.resnet18 = resnet18()\n",
    "        # The head takes 512 input features - presumably the backbone output width for\n",
    "        # 32x32 inputs; verify before changing the input size.\n",
    "        self.fc = torch.nn.Linear(512, categroies)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run `x` through the backbone, then through the linear head.\"\"\"\n",
    "        features = self.resnet18(x)\n",
    "        return self.fc(features)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = ResNet18(10)\n",
    "# Cross-entropy applies softmax internally, so the model's last layer needs no softmax activation.\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "# Fine-tuning: larger learning rate for the new final linear layer, smaller one for the base model.\n",
    "param_groups = [\n",
    "    dict(params=model.fc.parameters(), lr=1e-3),\n",
    "    dict(params=model.resnet18.parameters(), lr=1e-4),\n",
    "]\n",
    "# Adam's default learning rate is 1e-3; a group without its own 'lr' uses that default,\n",
    "# while a group that sets 'lr' uses the configured value.\n",
    "optimizer = torch.optim.Adam(params=param_groups)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 10])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke test: push one random 32x32 RGB image through the network and show the output shape.\n",
    "# A model containing BatchNorm layers cannot run with batch_size = 1 in training mode,\n",
    "# so switch to eval mode for this single-sample forward pass.\n",
    "model.eval()\n",
    "with torch.no_grad():  # shape check only, no autograd graph needed\n",
    "    # Named `dummy_out` rather than `test`, which would overwrite the CIFAR-10 test split\n",
    "    # tuple bound by the data-loading cell and break a re-run of that cell's second line.\n",
    "    dummy_out = model(torch.rand(1, 3, 32, 32))\n",
    "# Put the model back in training mode so this check does not leak eval mode into later cells.\n",
    "model.train()\n",
    "dummy_out.size()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 开始训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1563/1563 [00:43<00:00, 36.21it/s]\n",
      "  0%|          | 0/313 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############## taining ##############\n",
      "1/10, loss:0.6325461864471436, training Accuracy:0.54392\n",
      "############## taining ##############\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 313/313 [00:02<00:00, 151.03it/s]\n",
      "  0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############## testing ##############\n",
      "testing Average Accuracy is 0.6565\n",
      "############## testing ##############\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1563/1563 [00:43<00:00, 36.22it/s]\n",
      "  0%|          | 0/313 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############## taining ##############\n",
      "2/10, loss:1.0968945026397705, training Accuracy:0.64764\n",
      "############## taining ##############\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 313/313 [00:02<00:00, 147.35it/s]\n",
      "  0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############## testing ##############\n",
      "testing Average Accuracy is 0.6634\n",
      "############## testing ##############\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1563/1563 [00:43<00:00, 36.32it/s]\n",
      "  0%|          | 0/313 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############## taining ##############\n",
      "3/10, loss:0.9575943946838379, training Accuracy:0.6843\n",
      "############## taining ##############\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 313/313 [00:02<00:00, 144.13it/s]\n",
      "  0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############## testing ##############\n",
      "testing Average Accuracy is 0.7191\n",
      "############## testing ##############\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1563/1563 [00:43<00:00, 36.23it/s]\n",
      "  0%|          | 0/313 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############## taining ##############\n",
      "4/10, loss:1.037312388420105, training Accuracy:0.70872\n",
      "############## taining ##############\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 313/313 [00:02<00:00, 156.25it/s]\n",
      "  0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############## testing ##############\n",
      "testing Average Accuracy is 0.7296\n",
      "############## testing ##############\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1563/1563 [00:43<00:00, 36.19it/s]\n",
      "  0%|          | 0/313 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############## taining ##############\n",
      "5/10, loss:0.4431699812412262, training Accuracy:0.7225\n",
      "############## taining ##############\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 313/313 [00:02<00:00, 151.49it/s]\n",
      "  0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############## testing ##############\n",
      "testing Average Accuracy is 0.731\n",
      "############## testing ##############\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1563/1563 [00:43<00:00, 36.13it/s]\n",
      "  0%|          | 0/313 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############## taining ##############\n",
      "6/10, loss:0.9629826545715332, training Accuracy:0.73672\n",
      "############## taining ##############\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 313/313 [00:02<00:00, 143.64it/s]\n",
      "  0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############## testing ##############\n",
      "testing Average Accuracy is 0.7434\n",
      "############## testing ##############\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1563/1563 [00:43<00:00, 36.23it/s]\n",
      "  0%|          | 0/313 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############## taining ##############\n",
      "7/10, loss:0.4655568301677704, training Accuracy:0.7488\n",
      "############## taining ##############\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 313/313 [00:02<00:00, 155.48it/s]\n",
      "  0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############## testing ##############\n",
      "testing Average Accuracy is 0.7664\n",
      "############## testing ##############\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1563/1563 [00:43<00:00, 36.10it/s]\n",
      "  0%|          | 0/313 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############## taining ##############\n",
      "8/10, loss:0.8745599985122681, training Accuracy:0.75936\n",
      "############## taining ##############\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 313/313 [00:02<00:00, 141.43it/s]\n",
      "  0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############## testing ##############\n",
      "testing Average Accuracy is 0.7675\n",
      "############## testing ##############\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1563/1563 [00:43<00:00, 36.11it/s]\n",
      "  0%|          | 0/313 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############## taining ##############\n",
      "9/10, loss:0.5323635339736938, training Accuracy:0.76838\n",
      "############## taining ##############\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 313/313 [00:02<00:00, 145.87it/s]\n",
      "  0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############## testing ##############\n",
      "testing Average Accuracy is 0.7637\n",
      "############## testing ##############\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1563/1563 [00:43<00:00, 36.26it/s]\n",
      "  0%|          | 0/313 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############## taining ##############\n",
      "10/10, loss:0.40657275915145874, training Accuracy:0.77724\n",
      "############## taining ##############\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 313/313 [00:02<00:00, 145.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "############## testing ##############\n",
      "testing Average Accuracy is 0.7709\n",
      "############## testing ##############\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "epochs = 10\n",
    "# Use the first GPU when available, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "model = model.to(device)\n",
    "for epoch in range(epochs):\n",
    "    # model.eval() is called by the earlier shape-check cell and by the evaluation pass\n",
    "    # below, so training mode must be restored at the start of every epoch; otherwise\n",
    "    # the BatchNorm layers stay in inference mode for the whole training run.\n",
    "    model.train()\n",
    "    train_correct = 0\n",
    "    train_total = 0\n",
    "    train_loss_sum = 0.0\n",
    "    for img, tag in tqdm(trainloader):\n",
    "        img = img.to(device)\n",
    "        # CrossEntropyLoss needs int64 class indices. .long() keeps the tensor on `device`,\n",
    "        # whereas .type(torch.cuda.LongTensor) fails on the CPU fallback.\n",
    "        tag = tag.to(device).long()\n",
    "        output = model(img)\n",
    "        loss = criterion(output, tag)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Bookkeeping on detached logits so nothing here touches the autograd graph.\n",
    "        _, predicted = torch.max(output.detach(), 1)\n",
    "        batch_size = tag.shape[0]\n",
    "        train_correct += (predicted == tag).sum().item()\n",
    "        train_total += batch_size\n",
    "        # Weight by batch size so a smaller final batch does not skew the epoch mean.\n",
    "        train_loss_sum += loss.item() * batch_size\n",
    "\n",
    "    # Report the mean loss over the epoch rather than the last mini-batch's loss, which is noisy.\n",
    "    # max(..., 1) guards the division against an empty loader.\n",
    "    print('############## taining ##############')\n",
    "    print('{}/{}, loss:{}, training Accuracy:{}'.format(epoch+1, epochs, train_loss_sum / max(train_total, 1), train_correct / max(train_total, 1)))\n",
    "    print('############## taining ##############')\n",
    "\n",
    "    # Evaluation pass: eval mode for BatchNorm, and no gradient tracking.\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        correct = 0\n",
    "        total = 0\n",
    "        for img, tag in tqdm(testloader):\n",
    "            img = img.to(device)\n",
    "            tag = tag.to(device).long()\n",
    "            output = model(img)\n",
    "            _, predicted = torch.max(output, 1)\n",
    "            correct += (predicted == tag).sum().item()\n",
    "            total += tag.shape[0]\n",
    "        print('############## testing ##############')\n",
    "        print('testing Average Accuracy is {}'.format(correct / max(total, 1)))\n",
    "        print('############## testing ##############')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
